{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Federated Learning of a Recurrent Neural Network for text classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, you are going to learn how to train a Recurrent Neural Network (RNN) in a federated way with the purpose of *classifying* a person's surname to its most likely language of origin. \n",
    "\n",
    "\n",
    "We will train two Recurrent Neural Networks residing on two remote workers based on a dataset containing approximately 20.000 surnames from 18 languages of origin, and predict to which language a name belongs based on the name's spelling. \n",
    "\n",
    "A **character-level RNN** treats words as a series of characters - outputting a prediction and “hidden state” per character, feeding its previous hidden state into each next step. We take the final prediction to be the output, i.e. which class the word belongs to. Hence the training process proceeds sequentially character-by-character through the different hidden layers.\n",
    "\n",
    "Following distributed training, we are going to be able to perform predictions of a surname's language of origin, as in the following example:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```python\n",
    "predict(model_pointers[\"bob\"], \"Qing\", alice) #alice is our worker\n",
    "\n",
    " Qing\n",
    "(-1.43) Korean\n",
    "(-1.74) Vietnamese\n",
    "(-2.18) Arabic\n",
    "\n",
    "predict(model_pointers[\"alice\"], \"Daniele\", alice)\n",
    "\n",
    " Daniele\n",
    "(-1.58) French\n",
    "(-2.04) Scottish\n",
    "(-2.07) Dutch\n",
    "```"
   ]
  },
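  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The numbers in parentheses are not probabilities. The model defined later in this tutorial ends with an `nn.LogSoftmax` layer, so each score is a log-probability, and exponentiating it recovers the probability. A minimal sketch using the three scores printed for \"Qing\" above (the helper name `to_probability` is only illustrative):\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "# Log-probabilities copied from the sample output for 'Qing'\n",
    "scores = {'Korean': -1.43, 'Vietnamese': -1.74, 'Arabic': -2.18}\n",
    "\n",
    "def to_probability(log_prob):\n",
    "    # Inverse of the logarithm applied by LogSoftmax\n",
    "    return math.exp(log_prob)\n",
    "\n",
    "probabilities = {lang: to_probability(s) for lang, s in scores.items()}\n",
    "for lang, p in probabilities.items():\n",
    "    print('%s: %.3f' % (lang, p))\n",
    "```\n",
    "\n",
    "The three values do not add up to 1 because only the top three of the 18 categories are printed."
   ]
  },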
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The present example is inspired by an official Pytorch [tutorial](https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html), which I ported to PySyft with  the purpose of learning a Recurrent Neural Network in a federated way.The present tutorial is self-contained, so there are no dependencies on external pieces of code apart from a few Python libraries.\n",
    "\n",
    "**RNN Tutorial's author**: Daniele Gadler. [@DanyEle](https://github.com/danyele) on Github.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Step: Dependencies!\n",
    "\n",
    "Make sure you have all the requires packages installed, or install them via the following command (assuming you didn't move the current Jupyter Notebook from its initial directory).\n",
    "\n",
    "After installing new packages, you may have to restart this Jupyter Notebook from the tool bar Kernel -> Restart "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install -r \"../../../pip-dep/requirements.txt\"!pip install -r \"../../../pip-dep/requirements_udacity.txt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import unicode_literals, print_function, division\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "import torch\n",
    "from io import open\n",
    "import glob\n",
    "import os\n",
    "import numpy as np\n",
    "import unicodedata\n",
    "import string\n",
    "import random\n",
    "import torch.nn as nn\n",
    "import time\n",
    "import math\n",
    "import pandas as pd\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as ticker\n",
    "import urllib.request\n",
    "from zipfile import ZipFile\n",
    "\n",
    "#hide TF-related warnings in PySyft\n",
    "import warnings \n",
    "warnings.filterwarnings(\"ignore\")\n",
    "import syft as sy\n",
    "from syft.frameworks.torch.fl import utils\n",
    "from syft.workers.websocket_client import WebsocketClientWorker\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Step: Data pre-processing and transformation\n",
    "\n",
    "We are going to train our neural network based on a dataset containing surnames from 18 languages of origin. So let's run the following lines to automatically download the dataset and extract it. Afterwards, you'll be able to parse the dataset in Python following the initialization of a few basic functions for parsing the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#create a function for checking if the dataset does indeed exist\n",
    "def dataset_exists():\n",
    "    return (os.path.isfile('./data/eng-fra.txt') and\n",
    "    #check if all 18 files are indeed in the ./data/names/ directory\n",
    "    os.path.isdir('./data/names/') and\n",
    "    os.path.isfile('./data/names/Arabic.txt') and\n",
    "    os.path.isfile('./data/names/Chinese.txt') and\n",
    "    os.path.isfile('./data/names/Czech.txt') and\n",
    "    os.path.isfile('./data/names/Dutch.txt') and\n",
    "    os.path.isfile('./data/names/English.txt') and\n",
    "    os.path.isfile('./data/names/French.txt') and\n",
    "    os.path.isfile('./data/names/German.txt') and\n",
    "    os.path.isfile('./data/names/Greek.txt') and\n",
    "    os.path.isfile('./data/names/Irish.txt') and\n",
    "    os.path.isfile('./data/names/Italian.txt') and\n",
    "    os.path.isfile('./data/names/Japanese.txt') and\n",
    "    os.path.isfile('./data/names/Korean.txt') and\n",
    "    os.path.isfile('./data/names/Polish.txt') and\n",
    "    os.path.isfile('./data/names/Portuguese.txt') and\n",
    "    os.path.isfile('./data/names/Russian.txt') and\n",
    "    os.path.isfile('./data/names/Scottish.txt') and\n",
    "    os.path.isfile('./data/names/Spanish.txt') and\n",
    "    os.path.isfile('./data/names/Vietnamese.txt'))\n",
    "\n",
    "    \n",
    "#If the dataset does not exist, then proceed to download the dataset anew\n",
    "if not dataset_exists():\n",
    "    #If the dataset does not already exist, let's download the dataset directly from the URL where it is hosted\n",
    "    print('Downloading the dataset with urllib2 to the current directory...')\n",
    "    url = 'https://download.pytorch.org/tutorial/data.zip'\n",
    "    urllib.request.urlretrieve(url, './data.zip')\n",
    "    print(\"The dataset was successfully downloaded\")\n",
    "    print(\"Unzipping the dataset...\")\n",
    "    with ZipFile('./data.zip', 'r') as zipObj:\n",
    "       # Extract all the contents of the zip file in current directory\n",
    "       zipObj.extractall()\n",
    "    print(\"Dataset successfully unzipped\")\n",
    "else:\n",
    "    print(\"Not downloading the dataset because it was already downloaded\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Load all the files in a certain path\n",
    "def findFiles(path):\n",
    "    return glob.glob(path)\n",
    "\n",
    "# Read a file and split into lines\n",
    "def readLines(filename):\n",
    "    lines = open(filename, encoding='utf-8').read().strip().split('\\n')\n",
    "    return [unicodeToAscii(line) for line in lines]\n",
    "\n",
    "#convert a string 's' in unicode format to ASCII format\n",
    "def unicodeToAscii(s):\n",
    "    return ''.join(\n",
    "        c for c in unicodedata.normalize('NFD', s)\n",
    "        if unicodedata.category(c) != 'Mn'\n",
    "        and c in all_letters\n",
    "    )\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_letters = string.ascii_letters + \" .,;'\"\n",
    "n_letters = len(all_letters)\n",
    "\n",
    "#dictionary containing the nation as key and the names as values\n",
    "#Example: category_lines[\"italian\"] = [\"Abandonato\",\"Abatangelo\",\"Abatantuono\",...]\n",
    "category_lines = {}\n",
    "#List containing the different categories in the data\n",
    "all_categories = []\n",
    "\n",
    "for filename in findFiles('data/names/*.txt'):\n",
    "    print(filename)\n",
    "    category = os.path.splitext(os.path.basename(filename))[0]\n",
    "    all_categories.append(category)\n",
    "    lines = readLines(filename)\n",
    "    category_lines[category] = lines   \n",
    "    \n",
    "n_categories = len(all_categories)\n",
    "\n",
    "print(\"Amount of categories:\" + str(n_categories))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we are going to format the data so as to make it compliant with the format requested by PySyft and Pytorch. Firstly, we define a dataset class, specifying how batches ought to be extracted from the dataset in order for them to be assigned to the different workers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LanguageDataset(Dataset):\n",
    "    #Constructor is mandatory\n",
    "        def __init__(self, text, labels, transform=None):\n",
    "            self.data = text\n",
    "            self.targets = labels #categories\n",
    "            #self.to_torchtensor()\n",
    "            self.transform = transform\n",
    "        \n",
    "        def to_torchtensor(self):            \n",
    "            self.data = torch.from_numpy(self.text, requires_grad=True)\n",
    "            self.labels = torch.from_numpy(self.targets, requires_grad=True)\n",
    "        \n",
    "        def __len__(self):\n",
    "            #Mandatory\n",
    "            '''Returns:\n",
    "                    Length [int]: Length of Dataset/batches\n",
    "            '''\n",
    "            return len(self.data)\n",
    "    \n",
    "        def __getitem__(self, idx): \n",
    "            #Mandatory \n",
    "            \n",
    "            '''Returns:\n",
    "                     Data [Torch Tensor]: \n",
    "                     Target [ Torch Tensor]:\n",
    "            '''\n",
    "            sample = self.data[idx]\n",
    "            target = self.targets[idx]\n",
    "                    \n",
    "            if self.transform:\n",
    "                sample = self.transform(sample)\n",
    "    \n",
    "            return sample,target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#The list of arguments for our program. We will be needing most of them soon.\n",
    "class Arguments():\n",
    "    def __init__(self):\n",
    "        self.batch_size = 1\n",
    "        self.learning_rate = 0.005\n",
    "        self.epochs = 10000\n",
    "        self.federate_after_n_batches = 15000\n",
    "        self.seed = 1\n",
    "        self.print_every = 200\n",
    "        self.plot_every = 100\n",
    "        self.use_cuda = False\n",
    "        \n",
    "args = Arguments()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now need to unwrap data samples so as to have them all in one single list instead of a dictionary,where different categories were addressed by key.From now onwards, **categories** will be the languages of origin (Y) and **names** will be the data points (X)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%latex\n",
    "\n",
    "\\begin{split}\n",
    "names\\_list = [d_1,...,d_n]  \\\\\n",
    "\n",
    "category\\_list = [c_1,...,c_n] \n",
    "\\end{split}\n",
    "\n",
    "\n",
    "Where $n$ is the total amount of data points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Set of names(X)\n",
    "names_list = []\n",
    "#Set of labels (Y)\n",
    "category_list = []\n",
    "\n",
    "#Convert into a list with corresponding label.\n",
    "\n",
    "for nation, names in category_lines.items():\n",
    "    #iterate over every single name\n",
    "    for name in names:\n",
    "        names_list.append(name)      #input data point\n",
    "        category_list.append(nation) #label\n",
    "        \n",
    "#let's see if it was successfully loaded. Each data sample(X) should have its own corresponding category(Y)\n",
    "print(names_list[1:20])\n",
    "print(category_list[1:20])\n",
    "\n",
    "print(\"\\n \\n Amount of data points loaded: \" + str(len(names_list)))\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now need to turn our categories into numbers, as PyTorch cannot really understand plain text\n",
    "\n",
    "For an example category: \"Greek\" ---> 0"
   ]
  },
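  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of this mapping, the sketch below assigns integers in order of first appearance, which is how `pd.factorize` numbers values when it is not asked to sort them. The helper `factorize_in_order` and the sample labels are made up for this illustration.\n",
    "\n",
    "```python\n",
    "def factorize_in_order(labels):\n",
    "    # Map each distinct label to the next free integer, in order of first appearance\n",
    "    codes, uniques = [], {}\n",
    "    for label in labels:\n",
    "        if label not in uniques:\n",
    "            uniques[label] = len(uniques)\n",
    "        codes.append(uniques[label])\n",
    "    return codes, list(uniques)\n",
    "\n",
    "codes, uniques = factorize_in_order(['Greek', 'Greek', 'Irish', 'Czech', 'Irish'])\n",
    "print(codes)    # integer code of every label\n",
    "print(uniques)  # position i holds the label encoded as i\n",
    "```\n",
    "\n",
    "Because the labels are numbered in the order in which the categories were loaded, the integer codes line up with the positions in `all_categories`."
   ]
  },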
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Assign an integer to every category\n",
    "categories_numerical = pd.factorize(category_list)[0]\n",
    "#Let's wrap our categories with a tensor, so that it can be loaded by LanguageDataset\n",
    "category_tensor = torch.tensor(np.array(categories_numerical), dtype=torch.long)\n",
    "#Ready to be processed by torch.from_numpy in LanguageDataset\n",
    "categories_numpy = np.array(category_tensor)\n",
    "\n",
    "#Let's see a few resulting categories\n",
    "print(names_list[1200:1210])\n",
    "print(categories_numpy[1200:1210])\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now need to turn every single character in each input line string into a vector, with a \"1\" marking the character present in that very character.\n",
    "\n",
    "For example, in the case of a single character, we have:\n",
    "\n",
    "\"a\" = array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
    "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
    "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
    "        0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)\n",
    "\n",
    "\n",
    "A word is just a vector of such character vectors: our Recurrent Neural Network will process every single character vector in the word, producing an output after passing through each of its hidden layers. \n",
    "\n",
    "This technique, involving the encoding of a word as a vector of character vectors, is known as *word embedding*, as we embed a word into a vector of vectors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def letterToIndex(letter):\n",
    "    return all_letters.find(letter)\n",
    "    \n",
    "# Just for demonstration, turn a letter into a <1 x n_letters> Tensor\n",
    "def letterToTensor(letter):\n",
    "    tensor = torch.zeros(1, n_letters)\n",
    "    tensor[0][letterToIndex(letter)] = 1\n",
    "    return tensor\n",
    "\n",
    "# Turn a line into a <line_length x 1 x n_letters>,\n",
    "# or an array of one-hot letter vectors\n",
    "def lineToTensor(line):\n",
    "    tensor = torch.zeros(len(line), 1, n_letters) #Daniele: len(max_line_size) was len(line)\n",
    "    for li, letter in enumerate(line):\n",
    "        tensor[li][0][letterToIndex(letter)] = 1\n",
    "    #Daniele: add blank elements over here\n",
    "    return tensor    \n",
    "    \n",
    "    \n",
    "    \n",
    "def list_strings_to_list_tensors(names_list):\n",
    "    lines_tensors = []\n",
    "    for index, line in enumerate(names_list):\n",
    "        lineTensor = lineToTensor(line)\n",
    "        lineNumpy = lineTensor.numpy()\n",
    "        lines_tensors.append(lineNumpy)\n",
    "        \n",
    "    return(lines_tensors)\n",
    "\n",
    "lines_tensors = list_strings_to_list_tensors(names_list)\n",
    "\n",
    "print(names_list[0])\n",
    "print(lines_tensors[0])\n",
    "print(lines_tensors[0].shape)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's now identify the longest word in the dataset, as all tensors need to have the same shape  in order to fit into a numpy array. So, we append vectors containing just \"0\"s into our words up to the maximum word size, such that all word embeddings have the same shape."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "max_line_size = max(len(x) for x in lines_tensors)\n",
    "\n",
    "def lineToTensorFillEmpty(line, max_line_size):\n",
    "    tensor = torch.zeros(max_line_size, 1, n_letters) #notice the difference between this method and the previous one\n",
    "    for li, letter in enumerate(line):\n",
    "        tensor[li][0][letterToIndex(letter)] = 1\n",
    "        \n",
    "        #Vectors with (0,0,.... ,0) are placed where there are no characters\n",
    "    return tensor\n",
    "\n",
    "def list_strings_to_list_tensors_fill_empty(names_list):\n",
    "    lines_tensors = []\n",
    "    for index, line in enumerate(names_list):\n",
    "        lineTensor = lineToTensorFillEmpty(line, max_line_size)\n",
    "        lines_tensors.append(lineTensor)\n",
    "    return(lines_tensors)\n",
    "\n",
    "lines_tensors = list_strings_to_list_tensors_fill_empty(names_list)\n",
    "\n",
    "#Let's take a look at what a word now looks like\n",
    "print(names_list[0])\n",
    "print(lines_tensors[0])\n",
    "print(lines_tensors[0].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#And finally, from a list, we can create a numpy array with all our word embeddings having the same shape:\n",
    "array_lines_tensors = np.stack(lines_tensors)\n",
    "#However, such operation introduces one extra dimension (look at the dimension with index=2 having size '1')\n",
    "print(array_lines_tensors.shape)\n",
    "#Because that dimension just has size 1, we can get rid of it with the following function call\n",
    "array_lines_proper_dimension = np.squeeze(array_lines_tensors, axis=2)\n",
    "print(array_lines_proper_dimension.shape)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data unbalancing and batch randomization:\n",
    "You may have noticed that our dataset is strongly unbalanced and contains a lot of data points in the \"russian.txt\" dataset. However, we would still like to take a random batch during our training procedure at every iteration. In order to prevent our neural network from classifying a data point as always belonging to the \"Russian\" category, we first pick a random category and then select a data point from that category. To do that, we construct a dictionary mapping a certain category to the corresponding starting index in the list of data points (e.g.: lines). Afterwards, we will take a datapoint starting from the starting_index identified"
   ]
  },
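  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two-stage sampling described above can be sketched on a toy dataset. All names and counts below are invented for the illustration; only the idea (pick a category uniformly, then a name within it, then convert to an absolute index via the category's start index) follows the text.\n",
    "\n",
    "```python\n",
    "import random\n",
    "from collections import Counter\n",
    "\n",
    "random.seed(0)\n",
    "\n",
    "# Toy, deliberately unbalanced data: names are stored category after category\n",
    "toy_category_lines = {'Russian': ['r%d' % i for i in range(90)],\n",
    "                      'Irish': ['i%d' % i for i in range(7)],\n",
    "                      'Czech': ['c%d' % i for i in range(3)]}\n",
    "toy_labels = [cat for cat, names in toy_category_lines.items() for _ in names]\n",
    "\n",
    "# Start index of each category in the flat list\n",
    "toy_start_index, offset = {}, 0\n",
    "for cat, names in toy_category_lines.items():\n",
    "    toy_start_index[cat] = offset\n",
    "    offset += len(names)\n",
    "\n",
    "def balanced_random_index():\n",
    "    # Stage 1: every category is equally likely, regardless of its size\n",
    "    cat = random.choice(list(toy_category_lines))\n",
    "    # Stage 2: uniform position inside the chosen category\n",
    "    relative = random.randrange(len(toy_category_lines[cat]))\n",
    "    return toy_start_index[cat] + relative\n",
    "\n",
    "draws = Counter(toy_labels[balanced_random_index()] for _ in range(3000))\n",
    "print(draws)\n",
    "```\n",
    "\n",
    "Each category should be drawn about a third of the time, even though 'Russian' holds 90% of the toy names."
   ]
  },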
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_start_index_per_category(category_list):\n",
    "    categories_start_index = {}\n",
    "    \n",
    "    #Initialize every category with an empty list\n",
    "    for category in all_categories:\n",
    "        categories_start_index[category] = []\n",
    "    \n",
    "    #Insert the start index of each category into the dictionary categories_start_index\n",
    "    #Example: \"Italian\" --> 203\n",
    "    #         \"Spanish\" --> 19776\n",
    "    last_category = None\n",
    "    i = 0\n",
    "    for name in names_list:\n",
    "        cur_category = category_list[i]\n",
    "        if(cur_category != last_category):\n",
    "            categories_start_index[cur_category] = i\n",
    "            last_category = cur_category\n",
    "        \n",
    "        i = i + 1\n",
    "        \n",
    "    return(categories_start_index)\n",
    "\n",
    "categories_start_index = find_start_index_per_category(category_list)\n",
    "\n",
    "print(categories_start_index)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's define a few functions to take a random index from from the dataset, so that we'll be able to select a random data point and a random category."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def randomChoice(l):\n",
    "    rand_value = random.randint(0, len(l) - 1)\n",
    "    return l[rand_value], rand_value\n",
    "\n",
    "\n",
    "def randomTrainingIndex():\n",
    "    category, rand_cat_index = randomChoice(all_categories) #cat = category, it's not a random animal\n",
    "    #rand_line_index is a relative index for a data point within the random category rand_cat_index\n",
    "    line, rand_line_index = randomChoice(category_lines[category])\n",
    "    category_start_index = categories_start_index[category]\n",
    "    absolute_index = category_start_index + rand_line_index\n",
    "    return(absolute_index)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Step: Model - Recurrent Neural Network\n",
    "Hey, I must admit that was indeed a lot of data preprocessing and transformation, but it was well worth it! \n",
    "We have defined almost all the function we'll be needing during the training procedure and our data is ready\n",
    "to be fed into the neural network, which we're creating now:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Two hidden layers, based on simple linear layers\n",
    "\n",
    "class RNN(nn.Module):\n",
    "    def __init__(self, input_size, hidden_size, output_size):\n",
    "        super(RNN, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)\n",
    "        self.i2o = nn.Linear(input_size + hidden_size, output_size)\n",
    "        self.softmax = nn.LogSoftmax(dim=1)\n",
    "\n",
    "    def forward(self, input, hidden):\n",
    "        combined = torch.cat((input, hidden), 1)\n",
    "        hidden = self.i2h(combined)\n",
    "        output = self.i2o(combined)\n",
    "        output = self.softmax(output)\n",
    "        return output, hidden\n",
    "\n",
    "    def initHidden(self):\n",
    "        return torch.zeros(1, self.hidden_size)\n",
    "\n",
    "#Let's instantiate the neural network already:\n",
    "n_hidden = 128\n",
    "#Instantiate RNN\n",
    "\n",
    "device = torch.device(\"cuda\" if args.use_cuda else \"cpu\")\n",
    "model = RNN(n_letters, n_hidden, n_categories).to(device)\n",
    "#The final softmax layer will produce a probability for each one of our 18 categories\n",
    "print(model)    \n",
    "    "
   ]
  },
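  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written as equations, one step of the `forward` method above does the following, where $x_t$ is the one-hot vector of the $t$-th character, $h_{t-1}$ is the previous hidden state and $[x_t; h_{t-1}]$ is their concatenation. The symbols $W_{ih}, b_{ih}, W_{io}, b_{io}$ are notation introduced here for the weights and biases of `i2h` and `i2o`:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "h_t &= W_{ih}\\,[x_t; h_{t-1}] + b_{ih} \\\\\n",
    "o_t &= \\operatorname{LogSoftmax}\\left(W_{io}\\,[x_t; h_{t-1}] + b_{io}\\right)\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "The recursion starts from $h_0 = 0$ (see `initHidden`). Note that, as written, the hidden state update applies no non-linearity."
   ]
  },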
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Now let's define our workers. You can either use remote workers or virtual workers\n",
    "hook = sy.TorchHook(torch)  # <-- NEW: hook PyTorch ie add extra functionalities to support Federated Learning\n",
    "alice = sy.VirtualWorker(hook, id=\"alice\")  \n",
    "bob = sy.VirtualWorker(hook, id=\"bob\")  \n",
    "#charlie = sy.VirtualWorker(hook, id=\"charlie\") \n",
    "\n",
    "workers_virtual = [alice, bob]\n",
    "\n",
    "#If you have your workers operating remotely, like on Raspberry PIs\n",
    "#kwargs_websocket_alice = {\"host\": \"ip_alice\", \"hook\": hook}\n",
    "#alice = WebsocketClientWorker(id=\"alice\", port=8777, **kwargs_websocket_alice)\n",
    "#kwargs_websocket_bob = {\"host\": \"ip_bob\", \"hook\": hook}\n",
    "#bob = WebsocketClientWorker(id=\"bob\", port=8778, **kwargs_websocket_bob)\n",
    "#workers_virtual = [alice, bob]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#array_lines_proper_dimension = our data points(X)\n",
    "#categories_numpy = our labels (Y)\n",
    "langDataset =  LanguageDataset(array_lines_proper_dimension, categories_numpy)\n",
    "\n",
    "#assign the data points and the corresponding categories to workers.\n",
    "federated_train_loader = sy.FederatedDataLoader(\n",
    "            langDataset\n",
    "            .federate(workers_virtual),\n",
    "            batch_size=args.batch_size)  \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Step - Model Training!\n",
    "\n",
    "\n",
    "It's now time to train our Recurrent Neural Network based on the processed data. To do that, we need to define a few more functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def categoryFromOutput(output):\n",
    "    top_n, top_i = output.topk(1)\n",
    "    category_i = top_i[0].item()\n",
    "    return all_categories[category_i], category_i\n",
    "\n",
    "def timeSince(since):\n",
    "    now = time.time()\n",
    "    s = now - since\n",
    "    m = math.floor(s / 60)\n",
    "    s -= m * 60\n",
    "    return '%dm %ds' % (m, s)\n",
    "\n",
    "\n",
    "def fed_avg_every_n_iters(model_pointers, iter, federate_after_n_batches):\n",
    "        models_local = {}\n",
    "        \n",
    "        if(iter % args.federate_after_n_batches == 0):\n",
    "            for worker_name, model_pointer in model_pointers.items():\n",
    "#                #need to assign the model to the worker it belongs to.\n",
    "                models_local[worker_name] = model_pointer.copy().get()\n",
    "            model_avg = utils.federated_avg(models_local)\n",
    "           \n",
    "            for worker in workers_virtual:\n",
    "                model_copied_avg = model_avg.copy()\n",
    "                model_ptr = model_copied_avg.send(worker) \n",
    "                model_pointers[worker.id] = model_ptr\n",
    "                \n",
    "        return(model_pointers)     \n",
    "\n",
    "def fw_bw_pass_model(model_pointers, line_single, category_single):\n",
    "    #get the right initialized model\n",
    "    model_ptr = model_pointers[line_single.location.id]   \n",
    "    line_reshaped = line_single.reshape(max_line_size, 1, len(all_letters))\n",
    "    line_reshaped, category_single = line_reshaped.to(device), category_single.to(device)\n",
    "    #Firstly, initialize hidden layer\n",
    "    hidden_init = model_ptr.initHidden() \n",
    "    #And now zero grad the model\n",
    "    model_ptr.zero_grad()\n",
    "    hidden_ptr = hidden_init.send(line_single.location)\n",
    "    amount_lines_non_zero = len(torch.nonzero(line_reshaped.copy().get()))\n",
    "    #now need to perform forward passes\n",
    "    for i in range(amount_lines_non_zero): \n",
    "        output, hidden_ptr = model_ptr(line_reshaped[i], hidden_ptr) \n",
    "    criterion = nn.NLLLoss()   \n",
    "    loss = criterion(output, category_single) \n",
    "    loss.backward()\n",
    "    \n",
    "    model_got = model_ptr.get() \n",
    "    \n",
    "    #Perform model weights' updates    \n",
    "    for param in model_got.parameters():\n",
    "        param.data.add_(-args.learning_rate, param.grad.data)\n",
    "        \n",
    "        \n",
    "    model_sent = model_got.send(line_single.location.id)\n",
    "    model_pointers[line_single.location.id] = model_sent\n",
    "    \n",
    "    return(model_pointers, loss, output)\n",
    "            \n",
    "  \n",
    "    \n",
    "def train_RNN(n_iters, print_every, plot_every, federate_after_n_batches, list_federated_train_loader):\n",
    "    current_loss = 0\n",
    "    all_losses = []    \n",
    "    \n",
    "    model_pointers = {}\n",
    "    \n",
    "    #Send the initialized model to every single worker just before the training procedure starts\n",
    "    for worker in workers_virtual:\n",
    "        model_copied = model.copy()\n",
    "        model_ptr = model_copied.send(worker) \n",
    "        model_pointers[worker.id] = model_ptr\n",
    "\n",
    "    #extract a random element from the list and perform training on it\n",
    "    for iter in range(1, n_iters + 1):        \n",
    "        random_index = randomTrainingIndex()\n",
    "        line_single, category_single = list_federated_train_loader[random_index]\n",
    "        #print(category_single.copy().get())\n",
    "        line_name = names_list[random_index]\n",
    "        model_pointers, loss, output = fw_bw_pass_model(model_pointers, line_single, category_single)\n",
    "        #model_pointers = fed_avg_every_n_iters(model_pointers, iter, args.federate_after_n_batches)\n",
    "        #Update the current loss a\n",
    "        loss_got = loss.get().item() \n",
    "        current_loss += loss_got\n",
    "        \n",
    "        if iter % plot_every == 0:\n",
    "            all_losses.append(current_loss / plot_every)\n",
    "            current_loss = 0\n",
    "             \n",
    "        if(iter % print_every == 0):\n",
    "            output_got = output.get()  #Without copy()\n",
    "            guess, guess_i = categoryFromOutput(output_got)\n",
    "            category = all_categories[category_single.copy().get().item()]\n",
    "            correct = '✓' if guess == category else '✗ (%s)' % category\n",
    "            print('%d %d%% (%s) %.4f %s / %s %s' % (iter, iter / n_iters * 100, timeSince(start), loss_got, line_name, guess, correct))\n",
    "    return(all_losses, model_pointers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In order for the defined randomization process to work, we need to wrap the data points and categories into a list, from that we're going to take a batch at a random index.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#This may take a few seconds to complete.\n",
    "print(\"Generating list of batches for the workers...\")\n",
    "list_federated_train_loader = list(federated_train_loader) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And finally,let's launch our training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start = time.time()\n",
    "all_losses, model_pointers = train_RNN(args.epochs, args.print_every, args.plot_every, args.federate_after_n_batches, list_federated_train_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Let's plot the loss we got during the training procedure\n",
    "plt.figure()\n",
    "plt.ylabel(\"Loss\")\n",
    "plt.xlabel('Epochs (100s)')\n",
    "plt.plot(all_losses)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Step - Predict!\n",
    "Great! We have successfully created our two models for bob and alice in parallel using federated learning! I experimented with federated averaging of the two models, but it turned out that for a batch size of 1, as in the present case, the model loss was diverging. Let's try using our models for prediction now, shall we? This is the final reward for our endeavours."
   ]
  },
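  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the idea behind federated averaging can be sketched without PySyft. The snippet assumes the simplest variant, an unweighted element-wise mean of the parameters held by each worker. Parameters are represented as plain lists of floats, and the function name `average_parameters` is hypothetical. It is not the implementation behind `utils.federated_avg`.\n",
    "\n",
    "```python\n",
    "def average_parameters(worker_params):\n",
    "    # worker_params: {worker_name: {param_name: [float, ...]}}\n",
    "    workers = list(worker_params)\n",
    "    averaged = {}\n",
    "    for name in worker_params[workers[0]]:\n",
    "        # Group the i-th element of this parameter across all workers\n",
    "        columns = zip(*(worker_params[w][name] for w in workers))\n",
    "        averaged[name] = [sum(col) / len(workers) for col in columns]\n",
    "    return averaged\n",
    "\n",
    "# Made-up parameter values for two workers\n",
    "toy_params = {'alice': {'i2h.bias': [0.0, 2.0], 'i2o.bias': [1.0]},\n",
    "              'bob': {'i2h.bias': [1.0, 4.0], 'i2o.bias': [3.0]}}\n",
    "print(average_parameters(toy_params))\n",
    "```\n",
    "\n",
    "After averaging, every worker would receive the same set of parameters and continue training from there."
   ]
  },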
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(model, input_line, worker, n_predictions=3):\n",
    "    # Print and return the top `n_predictions` categories for `input_line`,\n",
    "    # running the forward pass remotely on `worker`.\n",
    "    print('\\n> %s' % input_line)\n",
    "    # Send the model and the encoded name to the remote worker\n",
    "    model_remote = model.send(worker)\n",
    "    line_tensor = lineToTensor(input_line)\n",
    "    line_remote = line_tensor.copy().send(worker)\n",
    "    # Create the initial hidden state and send it to the same worker\n",
    "    hidden = model_remote.initHidden()\n",
    "    hidden_remote = hidden.copy().send(worker)\n",
    "\n",
    "    # Feed the name through the RNN one character at a time\n",
    "    with torch.no_grad():\n",
    "        for i in range(line_remote.shape[0]):\n",
    "            output, hidden_remote = model_remote(line_remote[i], hidden_remote)\n",
    "\n",
    "    # Bring the final output back and take the top N categories\n",
    "    topv, topi = output.copy().get().topk(n_predictions, 1, True)\n",
    "    predictions = []\n",
    "\n",
    "    for i in range(n_predictions):\n",
    "        value = topv[0][i].item()\n",
    "        category_index = topi[0][i].item()\n",
    "        print('(%.2f) %s' % (value, all_categories[category_index]))\n",
    "        predictions.append([value, all_categories[category_index]])\n",
    "\n",
    "    return predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that the two trained models may make different predictions, depending on the data each of them was shown."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_alice = model_pointers[\"alice\"].get()\n",
    "model_bob = model_pointers[\"bob\"].get()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predict(model_alice.copy(), \"Qing\", alice) \n",
    "predict(model_alice.copy(), \"Daniele\", alice) \n",
    "\n",
    "predict(model_bob.copy(), \"Qing\", alice) \n",
    "predict(model_bob.copy(), \"Daniele\", alice) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can start experimenting with this example right away, for instance by increasing or decreasing the number of epochs and seeing how the two models perform. You can also try uncommenting the part about federated averaging and checking the resulting loss. There are plenty of other optimizations to explore as well!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Congratulations!!! - Time to Join the Community!\n",
    "\n",
    "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy-preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!\n",
    "\n",
    "### Star PySyft on GitHub\n",
    "\n",
    "The easiest way to help our community is just by starring the Repos! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### Join our Slack!\n",
    "\n",
    "The best way to keep up to date on the latest advancements is to join our community! You can do so by filling out the form at [http://slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### Join a Code Project!\n",
    "\n",
    "The best way to contribute to our community is to become a code contributor! At any time you can go to the PySyft GitHub Issues page and filter for \"Projects\". This will show you all the top-level tickets, giving an overview of the projects you can join! If you don't want to join a project, but you would like to do a bit of coding, you can also look for more \"one off\" mini-projects by searching for GitHub issues marked \"good first issue\".\n",
    "\n",
    "- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### Donate\n",
    "\n",
    "If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!\n",
    "\n",
    "[OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
